{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Configs\n",
    "You can always set up training configs directly in python scripts or with a yaml file. Refer to TrainingConfig for more API details.\n",
    "\n",
    "# 1. Default Configs\n",
    "You can dump default training configs or write customized training configs to a yaml file.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings #\n",
    "warnings.filterwarnings(\"ignore\") #\n",
    "from towhee.trainer.training_config import dump_default_yaml, TrainingConfig\n",
    "default_config_file = 'default_training_configs.yaml'\n",
    "dump_default_yaml(default_config_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can open default_training_configs.yaml, and you can get the default config yaml structure like this:\n",
    "```yaml\n",
    "train:\n",
    "    output_dir: ./output_dir\n",
    "    overwrite_output_dir: true\n",
    "    eval_strategy: epoch\n",
    "    eval_steps:\n",
    "    batch_size: 8\n",
    "    val_batch_size: -1\n",
    "    seed: 42\n",
    "    epoch_num: 2\n",
    "    dataloader_pin_memory: true\n",
    "    dataloader_drop_last: true\n",
    "    dataloader_num_workers: 0\n",
    "    load_best_model_at_end: false\n",
    "    freeze_bn: false\n",
    "device:\n",
    "    device_str:\n",
    "    sync_bn: false\n",
    "logging:\n",
    "    print_steps:\n",
    "learning:\n",
    "    lr: 5e-05\n",
    "    loss: CrossEntropyLoss\n",
    "    optimizer: Adam\n",
    "    lr_scheduler_type: linear\n",
    "    warmup_ratio: 0.0\n",
    "    warmup_steps: 0\n",
    "callback:\n",
    "    early_stopping:\n",
    "        monitor: eval_epoch_metric\n",
    "        patience: 4\n",
    "        mode: max\n",
    "    model_checkpoint:\n",
    "        every_n_epoch: 1\n",
    "    tensorboard:\n",
    "        log_dir:\n",
    "        comment: ''\n",
    "metrics:\n",
    "    metric: Accuracy\n",
    "```\n",
    "So the yaml file is corresponding to the TrainingConfig instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TrainingConfig(output_dir='./output_dir', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=8, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard={'log_dir': None, 'comment': ''}, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, freeze_bn=False)\n"
     ]
    }
   ],
   "source": [
    "training_configs = TrainingConfig().load_from_yaml(default_config_file)\n",
    "print(training_configs)\n",
    "training_configs.output_dir = 'my_test_output'\n",
    "training_configs.save_to_yaml('my_test_config.yaml')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open my_test_config.yaml, and you will find `output_dir` is modified:\n",
    "```yaml\n",
    "train:\n",
    "    output_dir: my_test_output\n",
    "```\n",
    "So there are 2 ways to set up the configs. One is using by class `TrainingConfig`, another is to overwrite the yaml file.\n",
    "\n",
    "# 2.Setting by TrainingConfig\n",
    "It's easy to set config using the TrainingConfig class. Just set the fields in TrainingConfig instance.\n",
    "You can get each config field introduction easily by `get_config_help()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- output_dir \n",
      "  default: ./output_dir \n",
      "  metadata_dict: {'help': 'The output directory where the model predictions and checkpoints will be written.', 'category': 'train'} \n",
      "\n",
      "- overwrite_output_dir \n",
      "  default: True \n",
      "  metadata_dict: {'help': 'Overwrite the content of the output directory.Use this to continue training if output_dir points to a checkpoint directory.', 'category': 'train'} \n",
      "\n",
      "- eval_strategy \n",
      "  default: epoch \n",
      "  metadata_dict: {'help': 'The evaluation strategy. It can be `steps`, `epoch`, `eval_epoch` or `no`,', 'category': 'train'} \n",
      "\n",
      "- eval_steps \n",
      "  default: None \n",
      "  metadata_dict: {'help': 'Run an evaluation every X steps.', 'category': 'train'} \n",
      "\n",
      "- batch_size \n",
      "  default: 8 \n",
      "  metadata_dict: {'help': 'Batch size for training.', 'category': 'train'} \n",
      "\n",
      "- val_batch_size \n",
      "  default: -1 \n",
      "  metadata_dict: {'help': 'Batch size for evaluation.', 'category': 'train'} \n",
      "\n",
      "- seed \n",
      "  default: 42 \n",
      "  metadata_dict: {'help': 'Random seed that will be set at the beginning of training.', 'category': 'train'} \n",
      "\n",
      "- epoch_num \n",
      "  default: 2 \n",
      "  metadata_dict: {'help': 'Total number of training epochs to perform.', 'category': 'train'} \n",
      "\n",
      "- dataloader_pin_memory \n",
      "  default: True \n",
      "  metadata_dict: {'help': 'Whether or not to pin memory for DataLoader.', 'category': 'train'} \n",
      "\n",
      "- dataloader_drop_last \n",
      "  default: True \n",
      "  metadata_dict: {'help': 'Drop the last incomplete batch if it is not divisible by the batch size.', 'category': 'train'} \n",
      "\n",
      "- dataloader_num_workers \n",
      "  default: 0 \n",
      "  metadata_dict: {'help': 'Number of subprocesses to use for data loading. Default 0 means that the data will be loaded in the main process. -1 means using all the cpu kernels, it will greatly improve the speed when distributed training.', 'category': 'train'} \n",
      "\n",
      "- lr \n",
      "  default: 5e-05 \n",
      "  metadata_dict: {'help': 'The initial learning rate.', 'category': 'learning'} \n",
      "\n",
      "- metric \n",
      "  default: Accuracy \n",
      "  metadata_dict: {'help': 'The metric to use to compare two different models.', 'category': 'metrics'} \n",
      "\n",
      "- print_steps \n",
      "  default: None \n",
      "  metadata_dict: {'help': 'If None, use the tqdm progress bar, otherwise it will print the logs on the screen every `print_steps`', 'category': 'logging'} \n",
      "\n",
      "- load_best_model_at_end \n",
      "  default: False \n",
      "  metadata_dict: {'help': 'Whether or not to load the best model found during training at the end of training.', 'category': 'train'} \n",
      "\n",
      "- early_stopping \n",
      "  default: <dataclasses._MISSING_TYPE object at 0x7fada7ded370> \n",
      "  metadata_dict: {'help': 'If the metrics is not better than before in several epoch, it will stop the training.', 'category': 'callback'} \n",
      "\n",
      "- model_checkpoint \n",
      "  default: <dataclasses._MISSING_TYPE object at 0x7fada7ded370> \n",
      "  metadata_dict: {'help': 'How many n epoch to save checkpoints', 'category': 'callback'} \n",
      "\n",
      "- tensorboard \n",
      "  default: <dataclasses._MISSING_TYPE object at 0x7fada7ded370> \n",
      "  metadata_dict: {'help': 'Tensorboard.', 'category': 'callback'} \n",
      "\n",
      "- loss \n",
      "  default: CrossEntropyLoss \n",
      "  metadata_dict: {'help': 'Pytorch loss in torch.nn package', 'category': 'learning'} \n",
      "\n",
      "- optimizer \n",
      "  default: Adam \n",
      "  metadata_dict: {'help': 'Pytorch optimizer Class name in torch.optim package', 'category': 'learning'} \n",
      "\n",
      "- lr_scheduler_type \n",
      "  default: linear \n",
      "  metadata_dict: {'help': 'The scheduler type to use.eg. `linear`, `cosine`, `cosine_with_restarts`, `polynomial`, `constant`, `constant_with_warmup`', 'category': 'learning'} \n",
      "\n",
      "- warmup_ratio \n",
      "  default: 0.0 \n",
      "  metadata_dict: {'help': 'Linear warmup over warmup_ratio fraction of total steps.', 'category': 'learning'} \n",
      "\n",
      "- warmup_steps \n",
      "  default: 0 \n",
      "  metadata_dict: {'help': 'Linear warmup over warmup_steps.', 'category': 'learning'} \n",
      "\n",
      "- device_str \n",
      "  default: None \n",
      "  metadata_dict: {'help': 'None -> If there is a cuda env in the machine, it will use cuda:0, else cpu.\\n`cpu` -> Use cpu only.\\n`cuda:2` -> Use the No.2 gpu, the same for other numbers.\\n`cuda` -> Use all available gpu, using data parallel. If you want to use several specified gpus to run, you can specify the environment variable `CUDA_VISIBLE_DEVICES` as the number of gpus you need before running your training script.', 'category': 'device'} \n",
      "\n",
      "- freeze_bn \n",
      "  default: False \n",
      "  metadata_dict: {'help': 'will completely freeze all BatchNorm layers during training.', 'category': 'train'} \n",
      "\n"
     ]
    }
   ],
   "source": [
    "from towhee.trainer.training_config import get_config_help\n",
    "help_dict = get_config_help() # get config field introductions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "You can construct config by the construct function, or then modify you custom value.\n",
    "```python\n",
    "training_configs = TrainingConfig(\n",
    "    xxx='some_value_xxx',\n",
    "    yyy='some_value_yyy'\n",
    ")\n",
    "# or\n",
    "training_configs.aaa='some_value_aaa'\n",
    "training_configs.bbb='some_value_bbb'\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 3.Setting by yaml file\n",
    "Your yaml file can be briefly with just some lines. You need not write the whole setting.\n",
    "```yaml\n",
    "train:\n",
    "    output_dir: my_another_output\n",
    "```\n",
    "A yaml like this also works. Default values will be overwritten if not written.\n",
    "There are some point you should pay attention.\n",
    "- If a value is None in python, no value is required after the colon.\n",
    "- If the value is `True`/`False` in python, it's `true`/`false` in yaml.\n",
    "- If the field is `str` instance in python, no quotation marks required.\n",
    "- If the field value is `dict` instance in python, start another line after the colon, each line after that is each key-value pair info.\n",
    "```yaml\n",
    "    early_stopping:\n",
    "        monitor: eval_epoch_metric\n",
    "        patience: 4\n",
    "        mode: max\n",
    "```\n",
    "equals\n",
    "```python\n",
    "early_stopping = {\n",
    "    'monitor': 'eval_epoch_metric',\n",
    "    'patience': 4,\n",
    "    'mode': 'max'\n",
    "    }\n",
    "```\n",
    "in python.\n"
   ]
  }
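  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustrative sketch, the rules above can be combined into one partial config file. The field names come from the default YAML shown earlier; the values are arbitrary examples, not recommendations. The quotes around `'no'` rest on an assumption about the parser: YAML 1.1 parsers such as PyYAML read a bare `no` as the boolean `false`, so a string value like this one is safer when quoted.\n",
    "```yaml\n",
    "train:\n",
    "    output_dir: my_another_output    # str: no quotation marks needed\n",
    "    overwrite_output_dir: false      # bool: lowercase true/false\n",
    "    eval_strategy: 'no'              # str that looks like a YAML boolean, so it is quoted\n",
    "    eval_steps:                      # None: nothing after the colon\n",
    "callback:\n",
    "    early_stopping:                  # dict: one key-value pair per indented line\n",
    "        monitor: eval_epoch_metric\n",
    "        patience: 4\n",
    "        mode: max\n",
    "```\n",
    "Fields that are left out, such as `batch_size` or the whole `learning` section, are simply not written in the file."
   ]
  }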
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
